package com.raos.websocket.server;

import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Component;

import javax.websocket.*;
import javax.websocket.server.PathParam;
import javax.websocket.server.ServerEndpoint;
import java.io.IOException;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

/**
 * WebSocketServer
 *  因为 WebSocket 是类似客户端服务端的形式(采用 ws 协议)，那么这里的 WebSocketServer 其实就相当于一个 ws 协议的 Controller
 *
 * @author raos
 * @email 991207823@qq.com
 * @date 2021/10/16 10:24
 */
@Component
@ServerEndpoint("/websocket/{userId}")
public class WebSocketServer {
    private static final Logger log = LoggerFactory.getLogger(WebSocketServer.class);

    /** 静态变量，用来记录当前在线连接数。它应该设计成线程安全的。*/
    private static int onlineCount = 0;

    /** 存放每个客户端对应的 MyWebSocket 对象的安全容器*/
    private static ConcurrentHashMap<String, WebSocketServer> webSocketMap = new ConcurrentHashMap<>();

    /** 与某个客户端的连接会话，需要通过它来给客户端发送数据*/
    private Session session;

    /** 接收userId*/
    private String userId = "";

    /**
     * 连接建立成功调用的方法
     */
    @OnOpen
    public void onOpen(Session session, @PathParam("userId") String userId) {
        this.session = session;
        this.userId = userId;
        // 判断这个用户是否已经处在连接中，是重新接入
        if (webSocketMap.containsKey(userId)) {
            webSocketMap.remove(userId);
            webSocketMap.put(userId, this);
        } else {
            webSocketMap.put(userId, this);
            // 加入map中, 在线数加1
            addOnlineCount();
        }
        log.info("当前连接用户ID【{}】，当前在线人数为【{}】" , userId, getOnlineCount());

        try {
            sendMessage("连接成功 ...");
        } catch (IOException e) {
            log.error("用户ID【{}】，网络异常!!! 错误信息【{}】", userId, e.getMessage());
        }
    }

    /**
     * 连接关闭调用的方法
     */
    @OnClose
    public void onClose() {
        if (webSocketMap.containsKey(userId)) {
            webSocketMap.remove(userId);
            // 从set中删除
            subOnlineCount();
        }
        log.info("【{}】用户退出，当前在线人数为【{}】", userId, getOnlineCount());
    }

    /**
     * 收到客户端消息后调用的方法
     * @param message 客户端发送过来的消息
     */
    @OnMessage
    public void onMessage(String message, Session session) {
        log.info("【{}】用户消息，报文【{}】",userId ,message);
        // 可以群发消息
        // 消息保存到数据库/redis
        if (StringUtils.isNotBlank(message)) {
            try {
                // 解析发送的报文
                JSONObject jsonObject = JSON.parseObject(message);
                // 追加发送人(防止串改)
                jsonObject.put("fromUserId", this.userId);
                String toUserId = jsonObject.getString("toUserId");
                // 传送给对应toUserId用户的websocket
                if (StringUtils.isNotBlank(toUserId) && webSocketMap.containsKey(toUserId)) {
                    webSocketMap.get(toUserId).sendMessage(jsonObject.toJSONString());
                } else {
                    // 否则不在这个服务器上，发送到mysql或者redis
                    log.error("请求的用户ID【{}】不在该服务器上", toUserId);
                }
            } catch (Exception e) {
                e.printStackTrace();
            }
        }
    }

    /**
     * 发送错误调用的方法
     * @param session
     * @param error 错误信息
     */
    @OnError
    public void onError(Session session, Throwable error) {
        try {
            log.error("【{}】用户错误，原因是【{}】", this.userId, error.getMessage());
            error.printStackTrace();
            webSocketMap.remove(this.userId);
            session.close();
        } catch (IOException e) {
            log.error("执行错误关闭异常，原因是【{}】", e.getMessage());
            e.printStackTrace();
        }
    }

    /**
     * 实现服务器主动推送
     */
    public void sendMessage(String message) throws IOException {
        this.session.getBasicRemote().sendText(message);
    }

    /**
     * 发送自定义消息
     */
    public static void sendInfo(String message, @PathParam("userId") String userId) throws IOException {
        log.info("发送消息到:【{}】，报文:【{}】", userId, message);
        if (StringUtils.isNotBlank(userId) && webSocketMap.containsKey(userId)) {
            webSocketMap.get(userId).sendMessage(message);
        } else {
            log.error("用户【{}】不在线！", userId);
        }
    }

    /**
     * 广播消息
     */
    public static void sendAll(String msg) throws IOException {
        log.info("批量发送消息，报文:【{}】", msg);
        Set<Map.Entry<String, WebSocketServer>> entrySet = webSocketMap.entrySet();
        for (Map.Entry<String, WebSocketServer> entry : entrySet) {
            entry.getValue().sendMessage(msg);
        }
    }

    public static synchronized int getOnlineCount() {
        return onlineCount;
    }

    public static synchronized void addOnlineCount() {
        WebSocketServer.onlineCount++;
    }

    public static synchronized void subOnlineCount() {
        WebSocketServer.onlineCount--;
    }

}
